{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification on CIFAR and ImageNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# check whether run in Colab\n",
    "root = \".\"\n",
    "if \"google.colab\" in sys.modules:\n",
    "    print(\"Running in Colab.\")\n",
    "    !pip3 install matplotlib\n",
    "    !pip3 install einops==0.3.0\n",
    "    !pip3 install timm==0.4.9\n",
    "    !git clone https://github.com/xxxnell/how-do-vits-work.git\n",
    "    root = \"./how-do-vits-work\"\n",
    "    sys.path.append(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import yaml\n",
    "import copy\n",
    "from pathlib import Path\n",
    "import datetime\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import models\n",
    "import ops.trains as trains\n",
    "import ops.tests as tests\n",
    "import ops.datasets as datasets\n",
    "import ops.schedulers as schedulers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# config_path = \"%s/configs/cifar10_vit.yaml\" % root\n",
    "config_path = \"%s/configs/cifar100_vit.yaml\" % root\n",
    "# config_path = \"%s/configs/imagenet_vit.yaml\" % root\n",
    "\n",
    "with open(config_path) as f:\n",
    "    args = yaml.load(f)\n",
    "    print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_args = copy.deepcopy(args).get(\"dataset\")\n",
    "train_args = copy.deepcopy(args).get(\"train\")\n",
    "val_args = copy.deepcopy(args).get(\"val\")\n",
    "model_args = copy.deepcopy(args).get(\"model\")\n",
    "optim_args = copy.deepcopy(args).get(\"optim\")\n",
    "env_args = copy.deepcopy(args).get(\"env\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_train, dataset_test = datasets.get_dataset(**dataset_args, download=True)\n",
    "dataset_name = dataset_args[\"name\"]\n",
    "num_classes = len(dataset_train.classes)\n",
    "\n",
    "dataset_train = DataLoader(dataset_train, \n",
    "                           shuffle=True, \n",
    "                           num_workers=train_args.get(\"num_workers\", 4), \n",
    "                           batch_size=train_args.get(\"batch_size\", 128))\n",
    "dataset_test = DataLoader(dataset_test, \n",
    "                          num_workers=val_args.get(\"num_workers\", 4), \n",
    "                          batch_size=val_args.get(\"batch_size\", 128))\n",
    "\n",
    "print(\"Train: %s, Test: %s, Classes: %s\" % (\n",
    "    len(dataset_train.dataset), \n",
    "    len(dataset_test.dataset), \n",
    "    num_classes\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use provided models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# ResNet\n",
    "# name = \"resnet_dnn_50\"\n",
    "# name = \"resnet_dnn_101\"\n",
    "\n",
    "# ViT\n",
    "name = \"vit_ti\"\n",
    "# name = \"vit_s\"\n",
    "\n",
    "vit_kwargs = {  # for CIFAR\n",
    "    \"image_size\": 32, \n",
    "    \"patch_size\": 2,\n",
    "}\n",
    "\n",
    "model = models.get_model(name, num_classes=num_classes, \n",
    "                         stem=model_args.get(\"stem\", False), **vit_kwargs)\n",
    "# models.load(model, dataset_name, uid=current_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Or use `timm`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "\n",
    "model = timm.models.vision_transformer.VisionTransformer(\n",
    "    img_size=32, patch_size=2, num_classes=num_classes,  # for CIFAR\n",
    "    embed_dim=192, depth=12, num_heads=3, qkv_bias=False,  # ViT-Ti\n",
    ")\n",
    "model.name = \"vit_ti\"\n",
    "models.stats(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Parallelize the given `moodel` by splitting the input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "name = model.name\n",
    "model = nn.DataParallel(model)\n",
    "model.name = name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a TensorBoard writer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "current_time = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "log_dir = os.path.join(\"runs\", dataset_name, model.name, current_time)\n",
    "writer = SummaryWriter(log_dir)\n",
    "\n",
    "with open(\"%s/config.yaml\" % log_dir, \"w\") as f:\n",
    "    yaml.dump(args, f)\n",
    "with open(\"%s/model.log\" % log_dir, \"w\") as f:\n",
    "    f.write(repr(model))\n",
    "\n",
    "print(\"Create TensorBoard log dir: \", log_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "gpu = torch.cuda.is_available()\n",
    "optimizer, train_scheduler = trains.get_optimizer(model, **optim_args)\n",
    "warmup_scheduler = schedulers.WarmupScheduler(optimizer, len(dataset_train) * train_args.get(\"warmup_epochs\", 0))\n",
    "\n",
    "trains.train(model, optimizer,\n",
    "             dataset_train, dataset_test,\n",
    "             train_scheduler, warmup_scheduler,\n",
    "             train_args, val_args, gpu,\n",
    "             writer, \n",
    "             snapshot=-1, dataset_name=dataset_name, uid=current_time)  # Set `snapshot=N` to save snapshots every N epochs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models.save(model, dataset_name, current_time, optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpu = torch.cuda.is_available()\n",
    "\n",
    "model = model.cuda() if gpu else model.cpu()\n",
    "metrics_list = []\n",
    "for n_ff in [1]:\n",
    "    print(\"N: %s, \" % n_ff, end=\"\")\n",
    "    *metrics, cal_diag = tests.test(model, n_ff, dataset_test, verbose=False, gpu=gpu)\n",
    "    metrics_list.append([n_ff, *metrics])\n",
    "\n",
    "leaderboard_path = os.path.join(\"leaderboard\", \"logs\", dataset_name, model.name)\n",
    "Path(leaderboard_path).mkdir(parents=True, exist_ok=True)\n",
    "metrics_dir = os.path.join(leaderboard_path, \"%s_%s_%s.csv\" % (dataset_name, model.name, current_time))\n",
    "tests.save_metrics(metrics_dir, metrics_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
